Let's write our own reverse-mode AD!
We will use Julia's dispatch system for simplicity. This means we create a type Tracked for keeping track of our input variables and everything we'll need to calculate the gradient later.
Each overload performs the operation itself (e.g. adding x and y), but also remembers the pullback map and the input variables for the reverse pass.
@λ (from the LegibleLambdas.jl package) only makes the printed output nicer. If we didn't care about that, we could replace @λ(Δ -> (Δ, Δ)) with Δ -> (Δ, Δ).
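The code cell for this step is not shown here, so the following is only a minimal sketch of what such a type and its overloads could look like. The field names (val, name, df, children), the :const marker for plain numbers, and the use of ordinary lambdas instead of @λ are assumptions made for this illustration.

```julia
# Sketch only: the field names below are assumptions, not the notebook's definitions.
struct Tracked{T}
    val::T                     # primal value computed in the forward pass
    name::Symbol               # variable name, or the operation that produced the value
    df                         # pullback Δ -> one cotangent per child (nothing for leaves)
    children::Vector{Tracked}  # inputs of the operation (empty for leaves)
end

# Leaves: input variables and plain constants (the :const marker is hypothetical).
Tracked(val, name::Symbol = :const) = Tracked(val, name, nothing, Tracked[])

# Each overload computes the primal value and stores the pullback plus its inputs.
Base.:+(x::Tracked, y::Tracked) = Tracked(x.val + y.val, :+, Δ -> (Δ, Δ), Tracked[x, y])
Base.:-(x::Tracked, y::Tracked) = Tracked(x.val - y.val, :-, Δ -> (Δ, -Δ), Tracked[x, y])
Base.:*(x::Tracked, y::Tracked) =
    Tracked(x.val * y.val, :*, Δ -> (Δ * y.val, Δ * x.val), Tracked[x, y])
Base.:^(x::Tracked, n::Integer) =
    Tracked(x.val^n, Symbol("^", n), Δ -> (Δ * n * x.val^(n - 1),), Tracked[x])

# Plain numbers are wrapped as constant leaves so that 2 * x or x - 1 also work.
Base.:+(x::Tracked, y::Real) = x + Tracked(y)
Base.:+(x::Real, y::Tracked) = Tracked(x) + y
Base.:-(x::Tracked, y::Real) = x - Tracked(y)
Base.:-(x::Real, y::Tracked) = Tracked(x) - y
Base.:*(x::Tracked, y::Real) = x * Tracked(y)
Base.:*(x::Real, y::Tracked) = Tracked(x) * y

x = Tracked(2, :x)
y = Tracked(7, :y)
f = (2 * x) * y + (x - 1)^2
f.val      # primal result of the forward pass
f.df(1)    # pullback of the outermost +, applied to a seed of 1
```

Nothing is differentiated at this point: the forward pass only records, for every node, how an incoming Δ would be split among its inputs.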
Tracked is a tree: we just need to tell AbstractTrees.jl how to get the children of each node, and we get tree printing and iteration over all nodes for free.
Let's also overload show for nicer output:
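To make the two steps above concrete without any package dependency, here is a rough sketch. It assumes the same hypothetical Tracked layout (val, name, df, children) and uses a hand-written pre-order walk in place of AbstractTrees.jl. The helper names children, label and preorder are made up for this example; with AbstractTrees.jl one would extend AbstractTrees.children instead and get printing and iteration from the package.

```julia
# Sketch only: same assumed layout as before, repeated so the block runs on its own.
struct Tracked{T}
    val::T                     # primal value computed in the forward pass
    name::Symbol               # variable name, or the operation that produced the value
    df                         # pullback Δ -> one cotangent per child (nothing for leaves)
    children::Vector{Tracked}  # inputs of the operation (empty for leaves)
end
Tracked(val, name::Symbol = :const) = Tracked(val, name, nothing, Tracked[])
Base.:+(x::Tracked, y::Tracked) = Tracked(x.val + y.val, :+, Δ -> (Δ, Δ), Tracked[x, y])
Base.:*(x::Tracked, y::Tracked) =
    Tracked(x.val * y.val, :*, Δ -> (Δ * y.val, Δ * x.val), Tracked[x, y])

# The one piece of information a tree library needs: the children of a node.
children(t::Tracked) = t.children

# Compact labels: "2" for constants, "x=2" for variables, "(16, +)" for operations.
function label(t::Tracked)
    if isempty(children(t))
        return t.name === :const ? string(t.val) : "$(t.name)=$(t.val)"
    end
    return "($(t.val), $(t.name))"
end
Base.show(io::IO, t::Tracked) = print(io, label(t))

# Hand-written pre-order walk: visit a node first, then its children left to right.
function preorder!(out, t::Tracked)
    push!(out, t)
    foreach(c -> preorder!(out, c), children(t))
    return out
end
preorder(t::Tracked) = preorder!(Tracked[], t)

x = Tracked(2, :x)
y = Tracked(7, :y)
f = x * y + x
labels = string.(preorder(f))
```

Because x is used twice in x * y + x, it shows up twice in the walk. That is why the reverse pass has to add up contributions per node rather than overwrite them.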
Create some variables we want to eventually differentiate with respect to.
y=7
Straight away we get the primal result of our calculation:
29
To also get the gradient, we'll use PreOrderDFS to traverse the tree we just created from the top down.
(29, +)
├─ (28, *)
│  ├─ (4, *)
│  │  ├─ 2
│  │  └─ x=2
│  └─ y=7
└─ (1, ^2)
   └─ (1, -)
      ├─ x=2
      └─ 1
(29, +)
(28, *)
(4, *)
2
x=2
y=7
(1, ^2)
(1, -)
x=2
1
Ok, let's create our function grad which will accumulate all intermediate gradients into a dictionary:
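One possible shape for such a function is sketched below, again under the assumed Tracked layout (val, name, df, children). Instead of looping over PreOrderDFS, this version recurses from the top node and pushes each path's contribution down to the inputs, adding it to the dictionary entry of every node it passes. The helper name backprop! and the choice of an IdDict (keys compared by identity) are assumptions of this sketch.

```julia
# Sketch only: same assumed layout as before, repeated so the block runs on its own.
struct Tracked{T}
    val::T                     # primal value computed in the forward pass
    name::Symbol               # variable name, or the operation that produced the value
    df                         # pullback Δ -> one cotangent per child (nothing for leaves)
    children::Vector{Tracked}  # inputs of the operation (empty for leaves)
end
Tracked(val, name::Symbol = :const) = Tracked(val, name, nothing, Tracked[])
Base.:+(x::Tracked, y::Tracked) = Tracked(x.val + y.val, :+, Δ -> (Δ, Δ), Tracked[x, y])
Base.:*(x::Tracked, y::Tracked) =
    Tracked(x.val * y.val, :*, Δ -> (Δ * y.val, Δ * x.val), Tracked[x, y])

# Add the incoming cotangent Δ to this node's entry, then hand the pullback's
# outputs to the children. A node reached along several paths is summed up.
function backprop!(grads, node::Tracked, Δ)
    grads[node] = get(grads, node, zero(Δ)) + Δ
    node.df === nothing && return grads   # leaf: nothing further to propagate
    for (child, Δchild) in zip(node.children, node.df(Δ))
        backprop!(grads, child, Δchild)
    end
    return grads
end

# All intermediate gradients, seeded with 1 at the output node.
grad(f::Tracked) = backprop!(IdDict{Tracked,Any}(), f, one(f.val))
# Convenience method: gradient with respect to a single node.
grad(f::Tracked, x::Tracked) = grad(f)[x]

x = Tracked(2, :x)
y = Tracked(7, :y)
f = x * y + x
grad(f, x), grad(f, y)
```

For f = x * y + x the entry for x collects 7 from the product and 1 from the sum, which gives y + 1 = 8, while y only receives x = 2 from the product.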
grad (generic function with 1 method)
grad (generic function with 2 methods)
We can verify that it does the right thing:
(16, +)
├─ (14, *)
│ ├─ x=2
│ └─ y=7
└─ x=2
x=2
8
y=7
2
(16, +) ├─ (14, *) │ ├─ x=2 │ └─ y=7 └─ x=2
1
(14, *) ├─ x=2 └─ y=7
1
8
2
How can we visualize both the forward and the reverse pass?
We can also visualize each step we just took. First we do the forward calculation, during which we also build up our tree. Then we walk down the tree in the opposite direction to accumulate the gradient.
👽=12
:(x * exp(-0.5 * (x ^ 2 + y ^ 2)))
x=2
y=7
We can also visualize what Julia does in the forward pass on the code itself:
(21, *)
├─ 3
└─ y=7
(2, *)
├─ 2
└─ (1, -)
   ├─ x=2
   └─ 1
ad_steps (generic function with 1 method)
to_json (generic function with 1 method)
show_tree (generic function with 2 methods)
@visual_debug (macro with 1 method)